{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to save and restore a trained model\n",
    "\n",
    "After scrolling through the posts on [reddit.com/r/learnmachinelearning](https://www.reddit.com/r/learnmachinelearning/), I realized that the major bottlenecks of a machine learning project occur in the data input pipeline and in the final stage of the project, where you have to save the model and make predictions on new data.\n",
    "So I thought it would be useful to write a simple and straightforward tutorial showing how to save and restore a model built with TensorFlow Eager.\n",
    "\n",
    "### Tutorial flowchart\n",
    "----\n",
    "\n",
    "![img](tutorials_graphics/save_restore_model.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import useful libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import TensorFlow and TensorFlow Eager\n",
    "import tensorflow as tf\n",
    "import tensorflow.contrib.eager as tfe\n",
    "\n",
    "# Import function to generate a toy classification problem\n",
    "from sklearn.datasets import make_moons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch TensorFlow to eager execution for the rest of this kernel session.\n",
    "# Once activated it cannot be reversed (restart the kernel to undo it),\n",
    "# so run this cell just once.\n",
    "tfe.enable_eager_execution()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part I: Build a simple neural network model for binary classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class simple_nn(tf.keras.Model):\n",
    "    \"\"\" Feed-forward network with one hidden layer for classification.\n",
    "\n",
    "        Bundles the forward pass, the loss, the gradient computation and a\n",
    "        full-batch training loop in a single object.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        \"\"\" Define here the layers used during the forward-pass\n",
    "            of the neural network.\n",
    "        \"\"\"\n",
    "        super(simple_nn, self).__init__()\n",
    "        # Hidden layer.\n",
    "        self.dense_layer = tf.layers.Dense(10, activation=tf.nn.relu)\n",
    "        # Output layer. No activation.\n",
    "        self.output_layer = tf.layers.Dense(2, activation=None)\n",
    "\n",
    "    def predict(self, input_data):\n",
    "        \"\"\" Runs a forward-pass through the network.\n",
    "            Args:\n",
    "                input_data: 2D tensor of shape (n_samples, n_features).\n",
    "            Returns:\n",
    "                logits: unnormalized predictions.\n",
    "        \"\"\"\n",
    "        hidden_activations = self.dense_layer(input_data)\n",
    "        logits = self.output_layer(hidden_activations)\n",
    "        return logits\n",
    "\n",
    "    def loss_fn(self, input_data, target):\n",
    "        \"\"\" Defines the loss function used during training.\n",
    "            Args:\n",
    "                input_data: 2D tensor of shape (n_samples, n_features).\n",
    "                target: class labels passed as `labels` to the sparse\n",
    "                    softmax cross-entropy loss.\n",
    "            Returns:\n",
    "                loss: sparse softmax cross-entropy between `target` and\n",
    "                    the logits returned by `predict`.\n",
    "        \"\"\"\n",
    "        logits = self.predict(input_data)\n",
    "        loss = tf.losses.sparse_softmax_cross_entropy(labels=target, logits=logits)\n",
    "        return loss\n",
    "\n",
    "    def grads_fn(self, input_data, target):\n",
    "        \"\"\" Dynamically computes the gradients of the loss value\n",
    "            with respect to the parameters of the model, in each\n",
    "            forward pass.\n",
    "            Returns:\n",
    "                The gradients of the loss, in the same order as\n",
    "                `self.variables`.\n",
    "        \"\"\"\n",
    "        with tfe.GradientTape() as tape:\n",
    "            loss = self.loss_fn(input_data, target)\n",
    "        return tape.gradient(loss, self.variables)\n",
    "\n",
    "    def fit(self, input_data, target, optimizer, num_epochs=500, verbose=50):\n",
    "        \"\"\" Function to train the model, using the selected optimizer and\n",
    "            for the desired number of epochs.\n",
    "            Args:\n",
    "                input_data: 2D tensor of shape (n_samples, n_features).\n",
    "                target: class labels, forwarded to `loss_fn`.\n",
    "                optimizer: object whose `apply_gradients` is called with\n",
    "                    (gradient, variable) pairs once per epoch.\n",
    "                num_epochs: number of gradient steps on the whole input.\n",
    "                verbose: print the loss at the first epoch and then every\n",
    "                    `verbose` epochs. 0 or None disables printing.\n",
    "        \"\"\"\n",
    "        for i in range(num_epochs):\n",
    "            grads = self.grads_fn(input_data, target)\n",
    "            optimizer.apply_gradients(zip(grads, self.variables))\n",
    "            # Test `verbose` first so that verbose=0 silences the output\n",
    "            # instead of raising ZeroDivisionError in the modulo below.\n",
    "            if verbose and (i == 0 or (i + 1) % verbose == 0):\n",
    "                # The loss is re-evaluated after the update, so the printed\n",
    "                # value reflects the weights at the end of this epoch.\n",
    "                print('Loss at epoch %d: %f' %(i+1, self.loss_fn(input_data, target).numpy()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part II: Train model "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a toy dataset for classification.\n",
    "#   X: (n_samples, n_features) matrix holding the input features\n",
    "#   y: vector of length n_samples holding the targets\n",
    "X, y = make_moons(n_samples=100, noise=0.1, random_state=2018)\n",
    "\n",
    "# The first 80 samples are used for training...\n",
    "X_train = tf.constant(X[:80])\n",
    "y_train = tf.constant(y[:80])\n",
    "# ...and the remaining 20 are held out for testing.\n",
    "X_test = tf.constant(X[80:])\n",
    "y_test = tf.constant(y[80:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss at epoch 1: 0.658276\n",
      "Loss at epoch 50: 0.302146\n",
      "Loss at epoch 100: 0.268594\n",
      "Loss at epoch 150: 0.247425\n",
      "Loss at epoch 200: 0.229143\n",
      "Loss at epoch 250: 0.197839\n",
      "Loss at epoch 300: 0.143365\n",
      "Loss at epoch 350: 0.098039\n",
      "Loss at epoch 400: 0.070781\n",
      "Loss at epoch 450: 0.053753\n",
      "Loss at epoch 500: 0.042401\n"
     ]
    }
   ],
   "source": [
    "# Plain gradient descent with a learning rate of 0.5\n",
    "optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.5)\n",
    "\n",
    "# Train a freshly created network, reporting the loss every 50 epochs\n",
    "model = simple_nn()\n",
    "model.fit(X_train, y_train, optimizer, num_epochs=500, verbose=50)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part III: Save trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory in which the checkpoint files are stored\n",
    "checkpoint_directory = 'models_checkpoints/SimpleNN/'\n",
    "\n",
    "# Group everything that has to be saved in a single checkpoint object\n",
    "checkpoint = tfe.Checkpoint(\n",
    "    optimizer=optimizer,\n",
    "    model=model,\n",
    "    optimizer_step=tf.train.get_or_create_global_step(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'models_checkpoints/SimpleNN/ckpt-1'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Save trained model.\n",
    "# `file_prefix` is a filename prefix, not a directory: passing the bare\n",
    "# directory names the checkpoint files '-1.*'. Append a basename\n",
    "# (checkpoint_directory already ends with '/') so that they are written\n",
    "# as 'ckpt-1.*' inside the checkpoint directory instead.\n",
    "checkpoint.save(file_prefix=checkpoint_directory + 'ckpt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part IV: Restore trained model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from scratch: a brand-new optimizer and a new model instance\n",
    "optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.5)\n",
    "model = simple_nn()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory that holds the previously saved checkpoint files\n",
    "checkpoint_directory = 'models_checkpoints/SimpleNN/'\n",
    "\n",
    "# Bind the new optimizer and model instances to a checkpoint object\n",
    "checkpoint = tfe.Checkpoint(\n",
    "    optimizer=optimizer,\n",
    "    model=model,\n",
    "    optimizer_step=tf.train.get_or_create_global_step(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tensorflow.contrib.eager.python.checkpointable_utils.CheckpointLoadStatus at 0x7fcfd47d2048>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Restore model from latest checkpoint\n",
    "latest_checkpoint = tf.train.latest_checkpoint(checkpoint_directory)\n",
    "# latest_checkpoint() returns None when the directory holds no checkpoint;\n",
    "# fail loudly here rather than passing None on to restore().\n",
    "if latest_checkpoint is None:\n",
    "    raise FileNotFoundError('No checkpoint found in %s' % checkpoint_directory)\n",
    "checkpoint.restore(latest_checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part V: Check if the model was restored correctly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss at epoch 1: 0.042220\n"
     ]
    }
   ],
   "source": [
    "# Run a single epoch on the restored model; fit() applies the update\n",
    "# and then prints the resulting training loss.\n",
    "model.fit(input_data=X_train, target=y_train, optimizer=optimizer, num_epochs=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loss is consistent with the loss we obtained in the last epoch of the previous training run (it is slightly lower because `fit` applies one more gradient step before printing the loss) :)!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part VI: Make predictions on new data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward pass of the restored model on the held-out samples\n",
    "logits_test = model.predict(input_data=X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[ 1.54352813 -0.83117302]\n",
      " [-1.60523365  2.82397487]\n",
      " [ 2.87589525 -1.36463485]\n",
      " [-1.39461001  2.62404279]\n",
      " [ 0.82305161 -0.55651397]\n",
      " [ 3.53674391 -2.55593046]\n",
      " [-2.97344627  3.46589599]\n",
      " [-1.69372442  2.95660466]\n",
      " [-1.43226137  2.65357974]\n",
      " [ 3.11479995 -1.31765645]\n",
      " [-0.65841567  1.60468631]\n",
      " [-2.27454367  3.60553595]\n",
      " [-1.50170912  2.74410115]\n",
      " [ 0.76261479 -0.44574208]\n",
      " [ 2.34516959 -1.6859307 ]\n",
      " [ 1.92181942 -1.63766352]\n",
      " [ 4.06047684 -3.03988941]\n",
      " [ 1.00252324 -0.78900484]\n",
      " [ 2.79802993 -2.2139734 ]\n",
      " [-1.43933035  2.68037059]], shape=(20, 2), dtype=float64)\n"
     ]
    }
   ],
   "source": [
    "# Unnormalized predictions: one row per test sample, one column per class\n",
    "print(logits_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
